import json
import pkgutil
from dataclasses import dataclass, field, fields
from typing import List, Union

# Prompt fragments bundled as the package resource "fragments.json", loaded once at import time.
# Config reads its "instructions", "cot_triggers" and "answer_extractions" sections to expand "all"
# and to validate the *_keys it is given.
# NOTE(review): pkgutil.get_data returns None when the module's loader cannot serve resources, in which
# case json.loads fails at import with a TypeError — confirm this is only ever imported from a normal package.
FRAGMENTS = json.loads(pkgutil.get_data(__name__, "fragments.json"))


@dataclass
class Config:
    """Class for keeping track of the config variables:
    "idx_range": tuple(int, int) or str - Deprecated. Any value other than "all" must be a tuple of ints
        (start, end) with start < end. Default: "all"
    "instruction_keys": list(str) or str - Determines which instruction keys are used from fragments.json,
        the corresponding string will be inserted for "{instruction}" in the templates. A single string is
        wrapped in a list, an empty value (or None) becomes [None], and "all" selects None plus every
        instruction in fragments.json. A None entry means "no instruction". Default: [None] (No instruction)
    "cot_trigger_keys": list(str) or str - Determines which cot triggers are used from fragments.json,
        the corresponding string will be inserted for "{cot_trigger}" in the templates.
        Normalized like "instruction_keys". Default: ["kojima-01"]
    "answer_extraction_keys": list(str) or str - Determines which answer extraction prompts are used from
        fragments.json, the corresponding string will be inserted for "{answer_extraction}" in the templates.
        Normalized like "instruction_keys". Default: ["kojima-01"]
    "template_cot_generation": str - The model input in the text generation step, variables in curly braces.
        Intended variables: "instruction", "question", "answer_choices", "cot_trigger"
        (validation accepts the full list given for "template_answer_extraction").
        Default: {instruction}\\n\\n{question}\\n{answer_choices}\\n\\n{cot_trigger}
    "template_answer_extraction": str - The model input in the answer extraction step, variables in curly braces.
        Only these variables are allowed: "instruction", "question", "answer_choices", "cot_trigger",
        "cot", "answer_extraction"
        Default: {instruction}\\n\\n{question}\\n{answer_choices}\\n\\n{cot_trigger}{cot}\\n{answer_extraction}
    "author": str - Name of the person responsible for generation, Default: ""
    "api_service": str - Name of the used api service: "openai", "openai_chat", "huggingface_hub",
        "huggingface_endpoint" or "cohere". Plus a mock api service "mock_api" for debugging,
        Default: "huggingface_hub"
    "engine": str - Name of model used, look at website of api which are
        available, e.g. for "openai": "text-davinci-002", Default: "google/flan-t5-xl"
    "temperature": float - Describes how much randomness is in the generated output,
        0.0 means the model will only output the most likely answer, 1.0 means
        the model will also output very unlikely answers, Default: 0.0
    "max_tokens": int - Maximum length of output generated by model, Default: 256
    "api_time_interval": float - Pause between two api calls in seconds, Default: 1.0
    "verbose": bool - Verbosity flag for the code that consumes this config, Default: True
    "warn": bool - Print warnings preventing excessive api usage, Default: True

    Validation happens in __post_init__ and raises:
        ValueError: unknown api_service, a fragment key that is not a string in fragments.json,
            or a template variable outside the allowed list.
        AssertionError: a field has the wrong type, or idx_range is not an increasing pair of ints.
    """

    # deprecated, kept for backward compatibility
    idx_range: Union[tuple, str, None] = "all"
    # Passing a default list as an argument to dataclasses needs to be done with a default_factory
    # https://stackoverflow.com/questions/52063759/passing-default-list-argument-to-dataclasses
    instruction_keys: Union[List, str] = field(default_factory=lambda: [None])
    cot_trigger_keys: Union[List, str] = field(default_factory=lambda: ["kojima-01"])
    answer_extraction_keys: Union[List, str] = field(default_factory=lambda: ["kojima-01"])
    template_cot_generation: str = "{instruction}\n\n{question}\n{answer_choices}\n\n{cot_trigger}"
    template_answer_extraction: str = "{instruction}\n\n{question}\n{answer_choices}\n\n{cot_trigger}{cot}\n{answer_extraction}"
    author: str = ""
    api_service: str = "huggingface_hub"
    engine: str = "google/flan-t5-xl"
    temperature: Union[int, float] = 0.0
    max_tokens: int = 256
    api_time_interval: Union[int, float] = 1.0
    verbose: bool = True
    warn: bool = True
    # TODO: add a way to set the api key?

    def __post_init__(self):
        """Normalize the fragment-key fields in place and validate every field (see class docstring)."""
        import re

        # raise error if API service is not supported
        available_endpoints = ["openai", "openai_chat", "huggingface_hub", "huggingface_endpoint", "cohere", "mock_api"]
        if self.api_service not in available_endpoints:
            raise ValueError(f"API service '{self.api_service}' not in available endpoints {available_endpoints}")

        # turn "all" / empty / single-string values into lists and check every key against fragments.json
        self.instruction_keys = self._normalize_keys(self.instruction_keys, "instructions", "instruction")
        self.cot_trigger_keys = self._normalize_keys(self.cot_trigger_keys, "cot_triggers", "cot_trigger")
        self.answer_extraction_keys = self._normalize_keys(
            self.answer_extraction_keys, "answer_extractions", "answer_extraction"
        )

        # type-check the templates before scanning them, so a bad type reports this message
        # instead of a TypeError from the regex
        self._require(isinstance(self.template_cot_generation, str), "template_cot_generation must be a string")
        self._require(isinstance(self.template_answer_extraction, str), "template_answer_extraction must be a string")

        # check if the templates contain only allowed variables; each template is scanned on its own
        # so no "{...}" span can be formed across the boundary between the two templates
        allowed_variables = [
            "instruction",
            "question",
            "answer_choices",
            "cot_trigger",
            "cot",
            "answer_extraction",
        ]
        for template in (self.template_cot_generation, self.template_answer_extraction):
            for variable in re.findall(r"{(.*?)}", template):
                if variable not in allowed_variables:
                    raise ValueError(f"Given variable '{variable}' is not allowed in templates. Allowed variables are: {allowed_variables}")

        # deprecated idx_range: anything other than "all" must be an increasing (start, end) pair of ints
        if self.idx_range != "all":
            self._require(isinstance(self.idx_range, tuple), "idx_range must be a tuple")
            self._require(isinstance(self.idx_range[0], int), "idx_range must be a tuple of ints")
            self._require(isinstance(self.idx_range[1], int), "idx_range must be a tuple of ints")
            self._require(
                self.idx_range[0] < self.idx_range[1],
                "idx_range must be a tuple of ints with idx_range[0] < idx_range[1]",
            )

        # simple type checks
        self._require(isinstance(self.author, str), "author must be a string")
        self._require(isinstance(self.api_service, str), "api_service must be a string")
        self._require(isinstance(self.engine, str), "engine must be a string")
        self._require(isinstance(self.temperature, (int, float)), "temperature must be a int or float")
        self._require(isinstance(self.max_tokens, int), "max_tokens must be an int")
        self._require(isinstance(self.api_time_interval, (int, float)), "api_time_interval must be a int or float")
        self._require(isinstance(self.verbose, bool), "verbose must be a bool")
        self._require(isinstance(self.warn, bool), "warn must be a bool")

    @staticmethod
    def _require(condition, message):
        """Raise AssertionError(message) unless condition is truthy.

        Used instead of ``assert`` so validation still runs under ``python -O``; the exception
        type stays AssertionError so existing callers see the same error as before.
        """
        if not condition:
            raise AssertionError(message)

    @staticmethod
    def _normalize_keys(keys, section, label):
        """Return ``keys`` as a list of fragment keys from FRAGMENTS[section], validated.

        "all" -> [None] + every key of the section; empty/None -> [None]; a single string -> [string];
        any other value is returned unchanged (same object). None entries mean "no fragment" and are
        always accepted.

        Raises:
            ValueError: an entry is not None and is not a string key of FRAGMENTS[section]
                (non-string entries are rejected before the lookup, so unhashable ones cannot crash it).
        """
        if keys == "all":
            # None first so the "no fragment" variant is always queried as well
            keys = [None] + list(FRAGMENTS[section])
        elif not keys:
            keys = [None]
        elif isinstance(keys, str):
            keys = [keys]

        for key in keys:
            if key is None:
                continue
            if not isinstance(key, str) or key not in FRAGMENTS[section]:
                raise ValueError(f"Given {label} key '{key}' is not in fragments.json.")
        return keys

    @classmethod
    def from_dict(cls, d):
        """Build a Config from a mapping of field names to values (runs the same validation)."""
        return cls(**d)

    @staticmethod
    def _all_fields():
        """Return the names of all dataclass fields of Config, in declaration order."""
        return [f.name for f in fields(Config)]
